from torch.nn.functional import pad as tensor_pad
import torch


def tensor_divide(tensor, psize, overlap, pad=True):
    """
    Divide a 4D tensor into a row-major list of (optionally overlapping) blocks.

    When ``pad`` is True, H and W are first reflect-padded on the bottom/right
    up to the next multiple of ``psize``. When ``pad`` is False, the last block
    of each row/column absorbs the remainder, so it can be larger (or, for
    inputs smaller than ``psize``, smaller) than ``psize``. Every block also
    carries ``overlap`` pixels of reflect-padded context on each side.

    :param tensor: input tensor indexed as (B, C, H, W)
    :param psize: nominal block size along H and W
    :param overlap: context margin added on every side of each block (0 = none)
    :param pad: reflect-pad H and W up to a multiple of ``psize`` before splitting
    :return: List of blocks in row-major order (block (i, j) is at index
             ``i * w_block + j``)
    """
    B, C, H, W = tensor.shape

    # Pad so H and W become divisible by psize.
    if pad:
        h_pad = psize - H % psize if H % psize != 0 else 0
        w_pad = psize - W % psize if W % psize != 0 else 0
        H += h_pad
        W += w_pad
        if h_pad != 0 or w_pad != 0:
            # NOTE(review): `.data` detaches from autograd only on this path;
            # unpadded inputs yield views that still track gradients.
            tensor = tensor_pad(tensor, (0, w_pad, 0, h_pad), mode='reflect').data

    # At least one block per axis: with pad=False an axis shorter than psize
    # would otherwise yield zero blocks and the image would be silently dropped.
    h_block = max(1, H // psize)
    w_block = max(1, W // psize)

    if overlap != 0:
        # Surround the whole image with context so edge blocks get full margins.
        tensor = tensor_pad(tensor, (overlap, overlap, overlap, overlap), mode='reflect').data

    blocks = []
    for i in range(h_block):
        for j in range(w_block):
            # The last row/column runs to the tensor edge to absorb any remainder.
            end_h = tensor.shape[2] if i + 1 == h_block else (i + 1) * psize + 2 * overlap
            end_w = tensor.shape[3] if j + 1 == w_block else (j + 1) * psize + 2 * overlap
            blocks.append(tensor[:, :, i * psize: end_h, j * psize: end_w])
    return blocks


def tensor_merge(blocks, tensor, psize, overlap, pad=True, tensor_shape=None):
    """
    Combine blocks (as laid out by ``tensor_divide``) back into one big image.

    Each block has ``overlap`` pixels trimmed from every side before it is
    written into its grid cell. The grid is computed on the padded size when
    ``pad`` is True, and the result is cropped back to the original H x W.

    :param blocks: List of 4D tensors, or a single tensor whose first axis
                   indexes blocks; blocks are read in row-major order
    :param tensor: tensor with the same shape as the big image (only its
                   ``.shape`` is read); ignored when ``tensor_shape`` is given
    :param psize: nominal block size along H and W
    :param overlap: margin to strip from every side of each block
    :param pad: whether the blocks were produced with padding to a multiple of
                ``psize``
    :param tensor_shape: optional (B, C, H, W) overriding ``tensor.shape``
    :return: Tensor of shape (B, C, H, W) with the dtype/device of the blocks
    """
    B, C, H, W = tensor.shape if tensor_shape is None else tensor_shape

    # Grid size is derived from the padded extent, mirroring tensor_divide.
    H_full, W_full = H, W
    if pad:
        H_full += psize - H % psize if H % psize != 0 else 0
        W_full += psize - W % psize if W % psize != 0 else 0

    # At least one block per axis: with pad=False an axis shorter than psize
    # would otherwise produce an empty grid and return an uninitialized buffer.
    h_block = max(1, H_full // psize)
    w_block = max(1, W_full // psize)

    # Allocate on the blocks' own device/dtype rather than the legacy
    # torch.FloatTensor, which always forces CPU float32.
    first = blocks[0]
    tensor_new = torch.empty(B, C, H_full, W_full, dtype=first.dtype, device=first.device)

    for i in range(h_block):
        for j in range(w_block):
            # The last row/column runs to the edge to absorb any remainder.
            end_h = H_full if i + 1 == h_block else (i + 1) * psize
            end_w = W_full if j + 1 == w_block else (j + 1) * psize
            part = blocks[i * w_block + j]

            # Indexing a stacked tensor of blocks yields 3D parts; restore the batch axis.
            if len(part.shape) < 4:
                part = part.unsqueeze(0)

            tensor_new[:, :, i * psize: end_h, j * psize: end_w] = \
                part[:, :, overlap: part.shape[2] - overlap, overlap: part.shape[3] - overlap]

    # Remove the padded edges.
    return tensor_new[:, :, :H, :W]